"""
Phase 7: Final Tests Before Paper
Development Threshold Paper

1. Firth (penalized) logit for small-sample robustness
2. Fell-back country narratives
3. Mediation: Z₁ → savings → CA → crossing
4. Gulf/resource exclusion robustness
5. China analysis
"""

import sys

import numpy as np
import pandas as pd
import statsmodels.api as sm
from scipy import stats
from statsmodels.tools.sm_exceptions import PerfectSeparationError

sys.path.insert(0, '/mnt/c/demographics_capital_flows/multilateral/src')
from model import PanelGLS

# ═══════════════════════════════════════════════════════════════════════
# Load data
# ═══════════════════════════════════════════════════════════════════════
# Full country-year panel. `full` keeps every year (the China section reads
# its post-2024 rows as projections); `panel` is capped at 2024.
full = pd.read_csv('/mnt/c/demographics_capital_flows/multilateral/140_country/data/processed/full_panel.csv')
panel = full[full['year'] <= 2024].copy()

# Resource rents are an optional add-on: if they cannot be loaded or merged,
# the column is filled with NaN so the rent-based steps degrade gracefully
# instead of aborting the whole run.
try:
    rents = pd.read_csv('/mnt/c/demographics_capital_flows/multilateral/140_country/data/raw/wdi_resource_rents.csv')
    panel = panel.merge(rents[['iso3', 'year', 'resource_rents_gdp']], on=['iso3', 'year'], how='left')
except (OSError, KeyError, ValueError) as exc:
    # OSError: file missing or unreadable; KeyError: expected columns absent;
    # ValueError: empty/unparseable CSV or incompatible merge keys.
    # Anything else (e.g. KeyboardInterrupt) is deliberately not swallowed.
    print(f"WARNING: resource rents unavailable ({type(exc).__name__}: {exc}); using NaN")
    panel['resource_rents_gdp'] = np.nan

# Per-country crossing summary (one row per country, with *_entry covariates).
cross_df = pd.read_csv('/mnt/c/demographics_capital_flows/development_threshold/data/crossing_data.csv')
# Bounds of the transit zone, applied to `gdp_pc_ppp` ($9k and $25k).
LOWER = 9000
UPPER = 25000

# Countries with a recorded $9k crossing, excluding those flagged 'Always above'.
zone_entrants = cross_df[cross_df['cross_9k'].notna() & (cross_df['status'] != 'Always above')].copy()

# ═══════════════════════════════════════════════════════════════════════
# 1. FIRTH (PENALIZED) LOGIT
# ═══════════════════════════════════════════════════════════════════════
print("=" * 70)
print("1. FIRTH PENALIZED LOGIT (SMALL-SAMPLE ROBUSTNESS)")
print("=" * 70)

def firth_logit(y, X, max_iter=100, tol=1e-6):
    """
    Firth penalized logistic regression.
    Adds Jeffreys prior (penalty = 0.5 * log det(I)) to reduce small-sample bias.

    Parameters
    ----------
    y : array-like, shape (n,)
        Binary outcome coded 0/1.
    X : array-like, shape (n, k)
        Design matrix. No intercept is added here; include a constant column
        if one is wanted.
    max_iter : int
        Maximum number of scoring iterations.
    tol : float
        Iteration stops once the largest absolute coefficient update is
        below this value.

    Returns
    -------
    dict or None
        ``None`` if the information matrix is singular during iteration.
        Otherwise a dict with ``beta``, ``se`` and ``pvalues`` (Wald,
        two-sided normal), ``n_iter`` (iterations performed) and
        ``converged`` (whether the tolerance was reached).
    """
    y = np.asarray(y, dtype=float)
    X = np.asarray(X, dtype=float)
    n, k = X.shape
    beta = np.zeros(k)
    n_iter = 0          # stays 0 when max_iter <= 0 instead of being undefined
    converged = False

    for n_iter in range(1, max_iter + 1):
        pi = 1 / (1 + np.exp(-X @ beta))
        pi = np.clip(pi, 1e-10, 1 - 1e-10)

        # Only the diagonal of W is ever needed, so keep it as a vector
        # rather than materialising n x n matrices.
        w = pi * (1 - pi)
        XWX = X.T @ (X * w[:, None])

        try:
            XWX_inv = np.linalg.inv(XWX)
        except np.linalg.LinAlgError:
            print("  Firth: singular information matrix, no estimate returned")
            return None

        # Hat matrix diagonal: h_i = w_i * x_i' (X'WX)^-1 x_i
        h = w * np.sum((X @ XWX_inv) * X, axis=1)

        # Modified score with Firth correction
        U = X.T @ (y - pi + h * (0.5 - pi))

        delta = XWX_inv @ U
        beta += delta

        if np.max(np.abs(delta)) < tol:
            converged = True
            break

    # Standard errors from Fisher information at the final estimate
    pi = 1 / (1 + np.exp(-X @ beta))
    pi = np.clip(pi, 1e-10, 1 - 1e-10)
    w = pi * (1 - pi)
    try:
        cov = np.linalg.inv(X.T @ (X * w[:, None]))
        se = np.sqrt(np.diag(cov))
    except np.linalg.LinAlgError:
        se = np.full(k, np.nan)

    z = beta / se
    # sf() keeps precision in the far tail where 1 - cdf() rounds to zero.
    pvalues = 2 * stats.norm.sf(np.abs(z))

    return {'beta': beta, 'se': se, 'pvalues': pvalues,
            'n_iter': n_iter, 'converged': converged}


# Run Firth on key specifications: each entry maps a table label to the
# entry-year covariates of that model (outcome is `exited_above`).
specs = {
    'M1: Z₁ only': ['Z_1_entry'],
    'M2: Z₁ + GDP': ['Z_1_entry', 'gdp_pc_ppp_entry'],
    'M4: Z₁ + GDP + KAOPEN + trade': ['Z_1_entry', 'gdp_pc_ppp_entry', 'kaopen_entry', 'trade_openness_entry'],
}

for label, xvars in specs.items():
    available = [v for v in xvars if v in zone_entrants.columns]
    if not available:
        # Nothing to estimate; a constant-only fit would print an empty table.
        print(f"\n  {label}: skipped (none of {xvars} present in crossing data)")
        continue

    sample = zone_entrants[['exited_above'] + available].dropna()
    if len(sample) < 15:
        print(f"\n  {label}: skipped (n={len(sample)} < 15)")
        continue

    y = sample['exited_above'].values
    # has_constant='add' guarantees column 0 is the intercept, so positions in
    # `var_names` always line up with the columns of X.
    X = sm.add_constant(sample[available], has_constant='add').values
    var_names = ['const'] + available

    # Standard logit. Perfect separation is exactly the case the Firth fit is
    # meant to cover, so a failure here must not abort the comparison.
    try:
        std_model = sm.Logit(y, X).fit(disp=0)
    except (PerfectSeparationError, np.linalg.LinAlgError) as exc:
        std_model = None
        print(f"\n  {label}: standard logit failed ({type(exc).__name__}); reporting Firth only")

    # Firth logit (returns None when its information matrix is singular)
    firth = firth_logit(y, X)

    if std_model is None and firth is None:
        print(f"  {label}: no estimates available")
        continue

    print(f"\n  {label} (n={len(sample)}):")
    print(f"  {'Variable':30s} {'Std coef':>10s} {'Std p':>8s} {'Firth coef':>12s} {'Firth p':>8s}")
    print("  " + "-" * 72)
    for i, var in enumerate(var_names):
        if var == 'const':
            continue

        if std_model is not None:
            std_p = std_model.pvalues[i]
            std_sig = '***' if std_p < 0.01 else '**' if std_p < 0.05 else '*' if std_p < 0.1 else ''
            std_part = f"{std_model.params[i]:10.4f}{std_sig:3s} {std_p:8.4f}"
        else:
            # Same width as the coefficient + stars + p-value columns.
            std_part = f"{'Std failed':>22s}"

        if firth is not None:
            f_p = firth['pvalues'][i]
            f_sig = '***' if f_p < 0.01 else '**' if f_p < 0.05 else '*' if f_p < 0.1 else ''
            firth_part = f" {firth['beta'][i]:12.4f}{f_sig:3s} {f_p:8.4f}"
        else:
            firth_part = "    Firth failed"

        print(f"  {var:30s} {std_part}{firth_part}")


# ═══════════════════════════════════════════════════════════════════════
# 2. FELL-BACK COUNTRIES: NARRATIVE
# ═══════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("2. COUNTRIES THAT FELL BACK BELOW $9k")
print("=" * 70)

# Column order of the exported table. Passing it explicitly keeps the frame
# (and the sort below) valid even when no country qualifies.
fb_columns = [
    'iso3', 'entry_year',
    'peak_gdp', 'peak_year',
    'trough_gdp', 'trough_year',
    'last_gdp', 'last_year',
    'decline_pct', 'resource_rents_avg',
]

# Mean resource rents per country over all panel years, computed once.
# The group mean skips NaN and is NaN for countries with no rents data.
rents_by_iso = panel.groupby('iso3')['resource_rents_gdp'].mean()

# Find countries that entered the zone then fell back
fell_back = []
gdp_obs = panel[panel['gdp_pc_ppp'].notna()]
for iso, grp in gdp_obs.groupby('iso3', sort=False):
    c = grp.sort_values('year')
    if len(c) < 5:
        continue

    ever_in = c[(c['gdp_pc_ppp'] >= LOWER) & (c['gdp_pc_ppp'] <= UPPER)]
    if len(ever_in) == 0:
        continue

    # Entry = first year observed inside [LOWER, UPPER].
    entry_year = ever_in['year'].min()
    after_entry = c[c['year'] > entry_year]
    if not (after_entry['gdp_pc_ppp'] < LOWER).any():
        continue

    # Peak is taken over the country's whole history; trough only after entry.
    peak_gdp = c['gdp_pc_ppp'].max()
    peak_year = c.loc[c['gdp_pc_ppp'].idxmax(), 'year']
    trough_gdp = after_entry['gdp_pc_ppp'].min()
    trough_year = after_entry.loc[after_entry['gdp_pc_ppp'].idxmin(), 'year']
    latest = c.iloc[-1]

    fell_back.append({
        'iso3': iso,
        'entry_year': entry_year,
        'peak_gdp': peak_gdp, 'peak_year': peak_year,
        'trough_gdp': trough_gdp, 'trough_year': trough_year,
        'last_gdp': latest['gdp_pc_ppp'], 'last_year': int(latest['year']),
        'decline_pct': (trough_gdp - peak_gdp) / peak_gdp * 100,
        'resource_rents_avg': rents_by_iso.get(iso, np.nan),
    })

fb_df = pd.DataFrame(fell_back, columns=fb_columns).sort_values('decline_pct')

print(f"\n{len(fb_df)} countries fell back below $9k after entering the zone:\n")
print(f"{'Country':6s} {'Entry':>6s} {'Peak':>14s} {'Trough':>14s} {'Current':>14s} {'Decline':>8s} {'Rents':>6s}")
print("-" * 75)
for _, row in fb_df.iterrows():
    rents_str = f"{row['resource_rents_avg']:.1f}%" if pd.notna(row['resource_rents_avg']) else '—'
    print(f"{row['iso3']:6s} {int(row['entry_year']):>6d} "
          f"${row['peak_gdp']:>9,.0f}({int(row['peak_year'])}) "
          f"${row['trough_gdp']:>9,.0f}({int(row['trough_year'])}) "
          f"${row['last_gdp']:>9,.0f}({row['last_year']}) "
          f"{row['decline_pct']:>7.0f}% {rents_str:>6s}")

# Classify reasons. Precedence: hand-coded conflict list, then resource
# dependence (hand-coded list or average rents > 10% of GDP), then the size
# of the decline.
print("\nNarrative classification:")
conflict_states = ['SYR', 'IRQ', 'LBY', 'YEM', 'UKR', 'AFG']
resource_collapse = ['VEN', 'GAB', 'COG', 'AGO', 'GNQ', 'TKM']
for _, row in fb_df.iterrows():
    iso = row['iso3']
    if iso in conflict_states:
        reason = "CONFLICT"
    elif iso in resource_collapse or (pd.notna(row['resource_rents_avg']) and row['resource_rents_avg'] > 10):
        reason = "RESOURCE COLLAPSE"
    elif row['decline_pct'] < -40:
        reason = "SYSTEMIC COLLAPSE"
    else:
        reason = "STAGNATION/OTHER"
    print(f"  {iso}: {reason} (peak ${row['peak_gdp']:,.0f} → trough ${row['trough_gdp']:,.0f})")

fb_df.to_csv('/mnt/c/demographics_capital_flows/development_threshold/output/tables/table17_fell_back.csv', index=False)

# ═══════════════════════════════════════════════════════════════════════
# 3. MEDIATION: Z₁ → SAVINGS → CA → CROSSING
# ═══════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("3. MEDIATION ANALYSIS: Z₁ → SAVINGS → CA → CROSSING")
print("=" * 70)


def _sig_stars(p_value):
    """Return '***', '**', '*' or '' for p below 0.01, 0.05, 0.1 (NaN gives '')."""
    if p_value < 0.01:
        return '***'
    if p_value < 0.05:
        return '**'
    if p_value < 0.1:
        return '*'
    return ''


# NOTE(review): every `beta[0]` / `pvalues[0]` below assumes PanelGLS orders its
# coefficients like the regressor columns passed to fit(), with no leading
# intercept — confirm against model.PanelGLS.

# Step 1: Z₁ → savings (in zone, panel)
zone_panel = panel[(panel['gdp_pc_ppp'] >= LOWER) & (panel['gdp_pc_ppp'] <= UPPER)].copy()
controls = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'rgdp_growth', 'kaopen']

print("\nStep 1: Z₁ → Savings (PanelGLS, in zone)")
sav_vars = ['Z_1'] + controls
sav_sample = zone_panel[['iso3', 'year', 'gross_savings_gdp'] + sav_vars].dropna()
if len(sav_sample) >= 50:
    gls1 = PanelGLS()
    gls1.fit(sav_sample['gross_savings_gdp'], sav_sample[sav_vars], sav_sample['iso3'], sav_sample['year'])
    z1_on_sav = gls1.beta[0]
    z1_on_sav_p = gls1.pvalues[0]
    print(f"  Z₁ → Savings: β={z1_on_sav:.3f}, p={z1_on_sav_p:.4f}{_sig_stars(z1_on_sav_p)}")

# Step 2: Savings → CA (in zone)
print("\nStep 2: Savings → CA (PanelGLS, in zone)")
ca_vars = ['gross_savings_gdp'] + controls
ca_sample = zone_panel[['iso3', 'year', 'ca_gdp'] + ca_vars].dropna()
if len(ca_sample) >= 50:
    gls2 = PanelGLS()
    gls2.fit(ca_sample['ca_gdp'], ca_sample[ca_vars], ca_sample['iso3'], ca_sample['year'])
    sav_on_ca = gls2.beta[0]
    sav_on_ca_p = gls2.pvalues[0]
    print(f"  Savings → CA: β={sav_on_ca:.3f}, p={sav_on_ca_p:.4f}{_sig_stars(sav_on_ca_p)}")

# Step 3: Z₁ → CA direct vs with savings control
print("\nStep 3: Z₁ → CA direct vs controlling for savings")
# Direct (largest available sample: savings is not required to be observed)
ca_direct_vars = ['Z_1'] + controls
ca_dir_sample = zone_panel[['iso3', 'year', 'ca_gdp'] + ca_direct_vars].dropna()
if len(ca_dir_sample) >= 50:
    gls3a = PanelGLS()
    gls3a.fit(ca_dir_sample['ca_gdp'], ca_dir_sample[ca_direct_vars], ca_dir_sample['iso3'], ca_dir_sample['year'])
    z1_direct = gls3a.beta[0]
    z1_direct_p = gls3a.pvalues[0]
    print(f"  Z₁ → CA (direct): β={z1_direct:.3f}, p={z1_direct_p:.4f}{_sig_stars(z1_direct_p)}")

# With savings
ca_med_vars = ['Z_1', 'gross_savings_gdp'] + controls
ca_med_sample = zone_panel[['iso3', 'year', 'ca_gdp'] + ca_med_vars].dropna()
if len(ca_med_sample) >= 50:
    gls3b = PanelGLS()
    gls3b.fit(ca_med_sample['ca_gdp'], ca_med_sample[ca_med_vars], ca_med_sample['iso3'], ca_med_sample['year'])
    z1_mediated = gls3b.beta[0]
    z1_mediated_p = gls3b.pvalues[0]
    sav_in_ca = gls3b.beta[1]
    print(f"  Z₁ → CA (with savings): β={z1_mediated:.3f}, p={z1_mediated_p:.4f}{_sig_stars(z1_mediated_p)}")
    print(f"  Savings → CA (with Z₁): β={sav_in_ca:.3f}")

    # Mediation proportion. The two coefficients being differenced must come
    # from the same observations: the "direct" fit above also uses rows where
    # savings is missing, so re-estimate it on the mediated model's sample.
    # Otherwise the difference mixes mediation with a change in sample.
    gls3c = PanelGLS()
    gls3c.fit(ca_med_sample['ca_gdp'], ca_med_sample[ca_direct_vars], ca_med_sample['iso3'], ca_med_sample['year'])
    z1_direct_common = gls3c.beta[0]

    if abs(z1_direct_common) > 0:
        indirect = z1_direct_common - z1_mediated
        mediation_pct = indirect / z1_direct_common * 100
        print(f"\n  Direct effect (common sample, n={len(ca_med_sample)}): {z1_direct_common:.3f}")
        print(f"  Effect controlling for savings: {z1_mediated:.3f}")
        print(f"  Indirect (mediated) effect: {indirect:.3f}")
        print(f"  Mediation proportion: {mediation_pct:.1f}%")

# Step 4: Cross-sectional mediation — does avg savings predict crossing?
print("\nStep 4: Cross-sectional — average savings during transit")
# NOTE(review): the window is every panel year from the $9k crossing onward
# (no upper bound at exit), and both averages use only years with savings
# observed — confirm this matches the intended "during transit" definition.
for _, row in zone_entrants.iterrows():
    iso = row['iso3']
    entry = row['cross_9k']
    c = panel[(panel['iso3'] == iso) & (panel['year'] >= entry) & (panel['gross_savings_gdp'].notna())]
    zone_entrants.loc[zone_entrants['iso3'] == iso, 'avg_savings'] = c['gross_savings_gdp'].mean() if len(c) > 0 else np.nan
    zone_entrants.loc[zone_entrants['iso3'] == iso, 'avg_ca'] = c['ca_gdp'].mean() if len(c) > 0 else np.nan

# Welch t-tests (unequal variances), crossers vs non-crossers; each group
# needs at least three observations.
c_sav = zone_entrants[zone_entrants['exited_above'] == 1]['avg_savings'].dropna()
nc_sav = zone_entrants[zone_entrants['exited_above'] == 0]['avg_savings'].dropna()
if len(c_sav) >= 3 and len(nc_sav) >= 3:
    t, p = stats.ttest_ind(c_sav, nc_sav, equal_var=False)
    print(f"  Avg savings during transit: crossers={c_sav.mean():.1f}%, non-crossers={nc_sav.mean():.1f}%, p={p:.4f}")

c_ca = zone_entrants[zone_entrants['exited_above'] == 1]['avg_ca'].dropna()
nc_ca = zone_entrants[zone_entrants['exited_above'] == 0]['avg_ca'].dropna()
if len(c_ca) >= 3 and len(nc_ca) >= 3:
    t, p = stats.ttest_ind(c_ca, nc_ca, equal_var=False)
    print(f"  Avg CA during transit: crossers={c_ca.mean():.1f}%, non-crossers={nc_ca.mean():.1f}%, p={p:.4f}")

# ═══════════════════════════════════════════════════════════════════════
# 4. GULF / RESOURCE EXCLUSION
# ═══════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("4. GULF AND RESOURCE EXCLUSION ROBUSTNESS")
print("=" * 70)


def _report_exclusion_logit(sample, xvars, title):
    """Fit exited_above ~ const + xvars on `sample` and print the coefficients.

    Returns the fitted statsmodels result, or None when the sample has fewer
    than 15 rows or the fit fails (perfect separation / singular Hessian).
    """
    if len(sample) < 15:
        return None
    try:
        model = sm.Logit(sample['exited_above'], sm.add_constant(sample[xvars])).fit(disp=0)
    except (PerfectSeparationError, np.linalg.LinAlgError) as exc:
        # Small exclusion samples can separate perfectly; report and move on
        # rather than aborting the remaining sections.
        print(f"\n{title} (n={len(sample)}): logit failed ({type(exc).__name__})")
        return None

    print(f"\n{title} (n={len(sample)}):")
    for var in xvars:
        p_val = model.pvalues[var]
        sig = '***' if p_val < 0.01 else '**' if p_val < 0.05 else '*' if p_val < 0.1 else ''
        print(f"  {var}: coef={model.params[var]:.4f}, p={p_val:.4f}{sig}")
    print(f"  pseudo-R²={model.prsquared:.3f}")
    return model


# Gulf states in our data.
# NOTE(review): these are only listed with their status; no model below
# excludes them explicitly — they drop out only if the rent filters catch them.
gulf = ['SAU', 'ARE', 'KWT', 'QAT', 'BHR', 'OMN']
print("Gulf states classification:")
for iso in gulf:
    if iso in cross_df['iso3'].values:
        row = cross_df[cross_df['iso3'] == iso].iloc[0]
        print(f"  {iso}: {row['status']}")
    else:
        print(f"  {iso}: not in crossing data")

# Resource-dependent crossers. Missing entry rents are treated as 0, i.e. as
# not resource-dependent.
resource_crossers = zone_entrants[
    (zone_entrants['exited_above'] == 1) &
    (zone_entrants['resource_rents_gdp_entry'].fillna(0) >= 10)
]
print(f"\nResource-dependent crossers (rents ≥ 10% at entry): {len(resource_crossers)}")
for _, row in resource_crossers.iterrows():
    # Local name chosen so it does not shadow the `rents` DataFrame.
    entry_rents = row['resource_rents_gdp_entry']
    rents_str = f"{entry_rents:.1f}%" if pd.notna(entry_rents) else '—'
    print(f"  {row['iso3']}: rents={rents_str}")

# Logit excluding resource-dependent countries
non_resource = zone_entrants[zone_entrants['resource_rents_gdp_entry'].fillna(0) < 10]
logit_vars = ['Z_1_entry', 'gdp_pc_ppp_entry']
sample_nr = non_resource[['exited_above'] + logit_vars].dropna()
model_nr = _report_exclusion_logit(
    sample_nr, logit_vars, "Logit excluding resource-dependent (≥10%) countries")

# Even stricter: keep only countries whose average rents over the panel are
# below 5% (missing averages count as 0). One grouped mean replaces a panel
# scan per country and also creates the column when there are no entrants.
avg_rents_by_iso = panel.groupby('iso3')['resource_rents_gdp'].mean()
strict_non_resource = zone_entrants.copy()
strict_non_resource['avg_rents'] = strict_non_resource['iso3'].map(avg_rents_by_iso)

strict_nr = strict_non_resource[strict_non_resource['avg_rents'].fillna(0) < 5]
sample_strict = strict_nr[['exited_above'] + logit_vars].dropna()
model_strict = _report_exclusion_logit(
    sample_strict, logit_vars, "Strict exclusion (avg rents < 5%)")

# ═══════════════════════════════════════════════════════════════════════
# 5. CHINA BOX
# ═══════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("5. CHINA: THE PARADOX OF HIGH PROBABILITY AND CLOSED WINDOW")
print("=" * 70)


def _value_in_year(frame, year, column):
    """Return `column` from the first row of `frame` with that year, or NaN if none."""
    hits = frame.loc[frame['year'] == year, column]
    return hits.iloc[0] if len(hits) > 0 else np.nan


chn = panel[panel['iso3'] == 'CHN'].sort_values('year')
chn_gdp = chn[chn['gdp_pc_ppp'].notna()]

print("\nChina GDP/cap trajectory:")
for yr in [1990, 2000, 2010, 2015, 2020, 2024]:
    row = chn_gdp[chn_gdp['year'] == yr]
    if len(row) > 0:
        gdp = row['gdp_pc_ppp'].values[0]
        z1 = row['Z_1'].values[0]
        # old_dep is scaled by 100 for display; a missing value stays NaN.
        oadr = row['old_dep'].values[0] * 100
        print(f"  {yr}: GDP/cap=${gdp:,.0f}, Z₁={z1:.2f}, OADR={oadr:.1f}%")

# Projections come from `full`, which still holds the post-2024 rows.
chn_proj = full[full['iso3'] == 'CHN']
for yr in [2030, 2040, 2050]:
    row = chn_proj[chn_proj['year'] == yr]
    if len(row) > 0:
        z1 = row['Z_1'].values[0]
        oadr = row['old_dep'].values[0] * 100
        print(f"  {yr}: Z₁={z1:.2f}, OADR={oadr:.1f}% (projected)")

# When did China enter the zone? (first year with GDP/cap at or above LOWER;
# NaN when there is no such year)
chn_entry = chn_gdp[chn_gdp['gdp_pc_ppp'] >= LOWER]['year'].min()
chn_z1_at_entry = _value_in_year(chn, chn_entry, 'Z_1')
print(f"\nChina entered $9k zone: {int(chn_entry) if pd.notna(chn_entry) else '?'}")
print(f"Z₁ at entry: {chn_z1_at_entry:.2f}")

# Historical crossers' Z₁ at entry vs China. Year lookups go through
# _value_in_year so a missing 2024 / 2030 row prints 'nan' instead of raising.
crossers_z1 = zone_entrants[zone_entrants['exited_above'] == 1]['Z_1_entry'].dropna()
print(f"\nHistorical crossers' Z₁ at entry: mean={crossers_z1.mean():.3f}, p25={crossers_z1.quantile(0.25):.3f}")
print(f"China's Z₁ at $9k entry: {chn_z1_at_entry:.3f}")
print(f"China's current Z₁: {_value_in_year(chn, 2024, 'Z_1'):.3f}")

# Crossers' Z₁ at $25k exit
crossers_z1_exit = cross_df[cross_df['exited_above'] == 1]['Z_1_exit'].dropna()
print(f"Historical crossers' Z₁ at $25k exit: mean={crossers_z1_exit.mean():.3f}")
print(f"China's projected Z₁ at 2030: {_value_in_year(chn_proj, 2030, 'Z_1'):.2f}")

# Growth needed
current_gdp = _value_in_year(chn_gdp, 2024, 'gdp_pc_ppp')
if pd.isna(current_gdp):
    print("\nChina 2024 GDP/cap unavailable; skipping growth-gap calculation")
elif current_gdp >= UPPER:
    # The gap would be negative and the year counts meaningless.
    print(f"\nChina is already at or above $25k (GDP/cap=${current_gdp:,.0f})")
else:
    gap_pct = (UPPER - current_gdp) / current_gdp * 100
    print(f"\nChina needs: +{gap_pct:.0f}% GDP/cap growth to reach $25k")
    # Years to reach UPPER at a constant growth rate g: log(UPPER / y0) / log(1 + g)
    for rate_pct, growth_factor in ((5, 1.05), (3, 1.03), (2, 1.02)):
        years_needed = np.log(UPPER / current_gdp) / np.log(growth_factor)
        print(f"  At {rate_pct}% growth: ~{years_needed:.0f} years")

# Savings trajectory — is China's savings rate declining?
chn_recent = chn[(chn['year'] >= 2010) & (chn['gross_savings_gdp'].notna())]
if len(chn_recent) >= 5:
    slope, intercept, r_value, p_value, std_err = stats.linregress(
        chn_recent['year'], chn_recent['gross_savings_gdp'])
    earliest = chn_recent.iloc[0]
    latest = chn_recent.iloc[-1]
    # Label with the years actually observed: the first non-missing savings
    # value need not be 2010, nor the last 2024.
    print(f"\nChina savings trend ({int(earliest['year'])}-{int(latest['year'])}): "
          f"{slope:+.2f}pp/year (p={p_value:.4f})")
    print(f"  {int(earliest['year'])}: {earliest['gross_savings_gdp']:.1f}%")
    print(f"  {int(latest['year'])}: {latest['gross_savings_gdp']:.1f}%")

print("\n✓ Phase 7 complete.")
